"""
A module/class for creating LaTeX tables.

Original version written by Chris Burns:
https://users.obs.carnegiescience.edu/~cburns/site/?p=22
"""
import sys

import numpy as np
import pandas as pd
from scipy.stats import t

# Cell value types that are rendered with a panel's float format string (exact type match);
# any other cell value is written via str().
FLOAT_TYPES = [
    float,
    np.float16,
    np.float32,
    np.float64,
]

# Opening/closing delimiters for standard-error cells, keyed by the `bracket_marker` option.
BRACKETS = {opening: [opening, closing] for opening, closing in zip('([{', ')]}')}
BRACKETS['angled'] = ['$\\langle$', '$\\rangle$']


def none_to_emptylist(a):
    """Return ``[]`` when *a* is ``None`` or ``[None]``; otherwise return *a* unchanged."""
    for empty_marker in (None, [None]):
        # Same test as ``a in [None, [None]]``: identity first, then equality.
        if a is empty_marker or a == empty_marker:
            return []
    return a


class Table:
    """
    A module/class for creating LaTeX tables. In a nutshell, you create
    a table instance, add one or more panels of data, set options, then call the print method.

    Parameters
    ----------
    numcols: integer
        Number of columns in the table.
    data_only: boolean, default False
        If True, write only data and skip table preambles and footers. As a result, output is not a displayable LaTeX
        table but can be used as a part of another table via \\input. Note that in this case, all the following kwargs
        that define table-level elements are redundant (except for justs that should match the justifications of the
        table where the data will be used).
    fontsize: string, optional
        Table font size. Eligible values:
        tiny, scriptsize, footnotesize, small, normalsize, large, Large, LARGE, huge, Huge.
    width: string, optional
        If width='auto', include a hack to the tex-code that tries to automatically choose
        optimal table width. Set to None if this option produces bad formatting.
    sideways: boolean, default False
        If True, rotate table 90 degrees.
    justs: string or list of strings, default 'l'
        Justification for table elements. Typical inputs: 'c', 'r', 'l', 'd'. Alternative inputs e.g. 'p{1.0cm}'.
        A one-character string, a string containing braces (such as 'p{1.0cm}') or a one-element list is applied
        to every column. Any other string (e.g. 'lcr') or list must have length numcols, and each column is
        justified as specified by the associated character/element.
    caption: string, optional
        Table caption.
    caption_position: {'top', 'bottom'}, default 'top'
        Table caption position (above or below the actual table).
    label: string, optional
        Table label, written as \\label{table:<label>}. When a caption is given, the label is placed after it so
        that \\ref resolves to the table number.
    description: string, optional
        Typically a long multicolumn text in the beginning of the table that describes
        the content of the table.
        Tip: Can be set to refer to another file that contains the actual description (see example).
    footer: string, optional
        Table footer text.
    float_specifiers: string, optional
        Float placement specifiers. For example: 'h', 'ht'.

    Examples
    --------
    A = Table(2, width='auto', justs='c', caption='title', label='table_label',
              description='\\input{tables/table_desc.tex}', footer='footer text')
    A.add_panel(data)
    A.print_table(fp='output.tex')

    Required LaTeX packages:
    booktabs
    calc (if width='auto')
    rotating (if sideways=True)
    dcolumn (if 'd' is an argument in justs)

        Also define \\newcolumntype{d}{D{.}{.}{n}} immediately after
        \\usepackage{dcolumn}, where n is the number of decimals in the table (leave room for stars if present)

        Also define \\newcommand\\mc[1]{\\multicolumn{1}{r}{#1~~~~}}, where the number of ~ marks in the end equals n.

    If both bracket_marker='angled' and justs='d' are used, n = n + 1. Also define
    \\newcommand\\md[1]{\\multicolumn{1}{r}{#1}}.
    """

    def __init__(self, numcols, data_only=False, fontsize=None, width=None, sideways=False, justs='l', caption=None,
                 caption_position='top', label=None, description=None, footer=None, float_specifiers=None):

        self.numcols = numcols
        self.data_only = data_only
        if isinstance(justs, str):
            # A braced string such as 'p{1.0cm}' is a single column spec; splitting it into characters
            # (as a plain per-column string like 'lcr' is split) would yield ['p', '{', '1', ...].
            self.justs = [justs] * numcols if len(justs) == 1 or '{' in justs else list(justs)
        else:
            self.justs = list(justs)
            if len(self.justs) == 1:
                # Replicate the single element itself, not the one-element list.
                self.justs = self.justs * numcols

        assert len(self.justs) == numcols, f'Error, justs must have 1 or {numcols} elements'

        if fontsize is not None:
            fs = ['tiny', 'scriptsize', 'footnotesize', 'small', 'normalsize',
                  'large', 'Large', 'LARGE', 'huge', 'Huge']
            assert fontsize in fs, f'If defined, fontsize must be in {str(fs)}'

        assert caption_position in ['top', 'bottom'], "If defined, caption_position must be 'top' or 'bottom'"
        self.fontsize = fontsize
        self.width = width
        self.sideways = sideways
        self.caption = caption
        self.caption_position = caption_position
        self.label = label
        self.description = description
        self.footer = footer
        self.col_justs = []
        self.supheaders = {}
        self.supheader_ids = {}
        # Per-panel state, keyed by panel index (0 .. npanels-1).
        # headers[pn][0] / headers[pn][1] hold supheader / header rows; header_ids mirrors them with 1-based
        # column numbers (int) or (first, last) spans (tuple).
        self.headers = {}
        self.header_ids = {}
        # data[pn] is a list of data blocks; each block is a list of self.numcols column vectors.
        self.data = {}
        self.formatting = {}
        self.nrows = {}
        self.regression = {}
        self.stars = {}
        self.dof = {}
        self.means = {}
        self.bracket_marker = {}
        self.skip_cols = {}
        self.npanels = 0
        self.float_specifiers = float_specifiers

    def add_panel(self, data, formatting="%.2f", supheaders=None, supheader_cols=None, headers=None, header_cols=None,
                  regression_specs=None):
        """
        Add a new panel and populate it with a matrix of data.

        Parameters
        ----------
        data: list or pandas.DataFrame
            List must have the same length as the number of columns of the table, with each
            item being a list or numpy array holding one table column. A DataFrame must have the same
            number of columns as the table.
        formatting: string or list, default "%.2f"
            Specifies the formatting for float-type data. If string, all elements are formatted
            as given in input. If list, must have same length as the number of columns in the table,
            and each column will be formatted as specified by the element of the list.
        supheaders: list, optional
            List of strings that will be in the supheader. An element that equals '' is not
            associated with a cmidrule that separates supheader from header.
        supheader_cols: list, optional
            List of 1-based column indexes. If not specified, it is assumed the supheaders
            are in order and there are no multicolumns. If specified, you
            can indicate that the ith supheader spans several columns by setting the
            ith value of cols to a 2-tuple of first and last columns for the span.
            Together the entries must cover columns 1..numcols exactly once, in order.
        headers: list, optional
            List of strings that will be in the header.
        header_cols: list, optional
            Same format as supheader_cols, applied to headers.
        regression_specs: dictionary, optional
            If defined, panel is interpreted as column-by-column regression coefficients. Odd rows (1st, 3rd, ...)
            are interpreted as coefficients, and stars are added to the end to denote significance (optional).
            Even rows (2nd, 4th, ...) are interpreted as standard errors, and are enclosed in brackets.
            The number of data rows must therefore be even. The passed dictionary is not modified.
            Keys (all optional):
                stars: tuple
                    Significance levels for 1 star, 2 stars, ... in decreasing order, e.g. (0.1, 0.05, 0.01).
                    A coefficient receives the stars of the last level whose two-sided t-test p-value is below it.
                dof: integer or list of integers
                    Must be defined if stars is defined.
                    If integer, regressions in all columns are assumed to have the same degrees of freedom.
                    If list, must have same length as numcols, and each col will have degrees of freedom as specified
                    in the associated element of the list.
                means: float or list
                    If defined, t-tests are conducted against the null of estimates being equal to means (default 0).
                    If float, all coefficients are tested against this value. If list of floats, each parameter (row)
                    is tested against the associated element (one element per coefficient row).
                bracket_marker: string or list of strings
                    If defined, use alternative bracket markers for standard errors.
                    Possible values: '(', '[', '{', or 'angled'.
                    If list, must have same length as numcols. Each element defines what bracket marker to use
                    for that specific column.
                skip_cols: list of integers
                    1-based indexes of columns that are not interpreted as regression coefficients. If dof, means or
                    bracket_marker are defined as lists, the elements associated with skipped columns have no impact.

            Example: regression_specs={'stars': (0.1, 0.05, 0.01), 'dof': 100, 'means': None,
                                       'bracket_marker': '{', 'skip_cols': [4]}

        Examples
        --------
        add_panel(df, formatting=['%.2f'] * 3 + ['%.3f'] * 3, supheaders=['', 'A', 'B'],
                  supheader_cols=[1, (2,4), (5,6)], headers=['(1)', '(2)', '(3)', '(4)', '(5)', '(6)'])
        """

        if isinstance(data, pd.DataFrame):
            # One list per DataFrame column, i.e. per table column.
            data = data.values.T.tolist()

        assert isinstance(data, list), 'Data should be a list or Pandas DataFrame.'
        assert len(data) == self.numcols, 'Error, length of data must match number of table columns.'

        for datum in data:
            assert type(datum) in [list, np.ndarray], 'Data must be list of lists and numpy arrays.'
            assert len(np.shape(datum)) == 1, 'Data items must be vectors.'

        for names, cols in [[supheaders, supheader_cols], [headers, header_cols]]:
            if names is not None:
                assert isinstance(names, list), 'supheaders and headers must be lists.'
                if cols is None:
                    assert len(names) == self.numcols, \
                        f'supheaders and headers must be lists of length {self.numcols}'
                else:
                    assert isinstance(cols, list), 'supheader_cols and header_cols must be lists.'
                    assert len(names) == len(cols), \
                        '(sup)header and (sup)header_cols must have equal number of elements.'

        if regression_specs is not None:
            # Work on a copy so that filling in the missing keys does not mutate the caller's dict.
            regression_specs = dict(regression_specs)
            assert np.shape(data[0])[0] % 2 == 0, 'Number of data rows must be even if regression_specs is defined.'
            for key in ['stars', 'dof', 'means', 'bracket_marker', 'skip_cols']:
                regression_specs.setdefault(key, None)
            if regression_specs['stars'] is not None:
                assert isinstance(regression_specs['stars'], tuple), 'stars in regression_specs must be a tuple.'
                assert regression_specs['dof'] is not None, 'dof must be defined in regression_specs if stars is.'
                assert isinstance(regression_specs['dof'], (int, list)), 'dof in regression_specs must be int or list.'
                if isinstance(regression_specs['dof'], list):
                    assert len(regression_specs['dof']) == self.numcols, \
                        f'dof in regression_specs must have length {self.numcols}'
                if regression_specs['means'] is not None:
                    assert isinstance(regression_specs['means'], (float, list)), \
                        'means in regression_specs must be float or list.'
                    if isinstance(regression_specs['means'], list):
                        assert len(regression_specs['means']) == np.shape(data[0])[0] / 2, \
                            f'means in regression_specs must have length {np.shape(data[0])[0] / 2}'
            if regression_specs['bracket_marker'] is not None:
                assert isinstance(regression_specs['bracket_marker'], (str, list)), \
                    'bracket_marker in regression_specs must be string or list.'
                markers = regression_specs['bracket_marker']
                for val in [markers] if isinstance(markers, str) else markers:
                    assert val in BRACKETS.keys(), \
                        f'bracket_marker in regression specs must be in {list(BRACKETS.keys())}'
                if isinstance(regression_specs['bracket_marker'], list):
                    assert len(regression_specs['bracket_marker']) == self.numcols, \
                        f'bracket_marker in regression_specs must have length {self.numcols}'
            if regression_specs['skip_cols'] is not None:
                assert isinstance(regression_specs['skip_cols'], list), 'skip_cols in regression_specs must be list.'

        self.data[self.npanels] = []
        self.formatting[self.npanels] = []
        self.nrows[self.npanels] = []
        self.headers[self.npanels] = {}
        self.header_ids[self.npanels] = {}
        self.bracket_marker[self.npanels] = []

        nrows = np.shape(data[0])[0]
        for datum in data[1:]:
            assert np.shape(datum)[0] == nrows, 'Each data item must have same first dimension.'
        self.nrows[self.npanels].append(nrows)
        if len(np.shape(formatting)) == 0:
            self.formatting[self.npanels].append([formatting] * self.numcols)
        else:
            assert len(np.shape(formatting)) == 1, 'formatting must be scalar or have same length as number of columns'
            self.formatting[self.npanels].append(formatting)

        self.data[self.npanels].append(data)

        for names, cols, level in [[supheaders, supheader_cols, 0], [headers, header_cols, 1]]:
            self.headers[self.npanels][level] = []
            self.header_ids[self.npanels][level] = []
            if names is None:
                continue
            if cols is None:
                # One header per column. Use 1-based column numbers, the same convention as explicit cols,
                # so that the printing code can treat both cases identically.
                cols = list(range(1, self.numcols + 1))
            else:
                ids = []
                for item in cols:
                    if isinstance(item, int):
                        ids.append(item)
                    elif isinstance(item, tuple):
                        ids += range(item[0], item[1] + 1)
                assert ids == list(range(1, self.numcols + 1)), 'Missing or extra columns in cols'

            self.headers[self.npanels][level].append(names)
            self.header_ids[self.npanels][level].append(cols)

        if regression_specs is not None:
            self.regression[self.npanels] = True

            self.stars[self.npanels] = regression_specs['stars']
            self.dof[self.npanels] = regression_specs['dof']

            # One null-hypothesis mean per coefficient row (every other data row).
            if isinstance(regression_specs['means'], float):
                self.means[self.npanels] = [regression_specs['means']] * (nrows // 2)
            elif isinstance(regression_specs['means'], list):
                self.means[self.npanels] = regression_specs['means']
            else:
                self.means[self.npanels] = [0.0] * (nrows // 2)

            if regression_specs['bracket_marker'] is None:
                self.bracket_marker[self.npanels].append(['('] * self.numcols)
            elif isinstance(regression_specs['bracket_marker'], str):
                self.bracket_marker[self.npanels].append([regression_specs['bracket_marker']] * self.numcols)
            else:
                self.bracket_marker[self.npanels].append(regression_specs['bracket_marker'])

            self.skip_cols[self.npanels] = none_to_emptylist(regression_specs['skip_cols'])

        else:
            self.regression[self.npanels] = None
            self.stars[self.npanels] = None

        self.npanels += 1

    def print_table(self, fp=None):
        """
        Prints the table.

        Parameters
        ----------
        fp: string, path-like or file-like object, optional
            If None, prints to console. An object with a ``write`` method is written to and left open.
            Anything else is treated as a file path that is opened for writing and closed afterwards.

        Examples
        --------
        print_table(fp='C:/Users/Documents/output.tex')
        """
        if fp is None:
            out, owned = sys.stdout, False
        elif hasattr(fp, 'write'):
            out, owned = fp, False
        else:
            # str or os.PathLike: we open it, so we must close it even if printing fails.
            out, owned = open(fp, 'w'), True

        try:
            if not self.data_only:
                self._print_preamble(out)
            for pn in range(self.npanels):
                self._print_headers(out, pn)
                self._print_data(out, pn)
            if not self.data_only:
                self._print_footer(out)
        finally:
            if owned:
                out.close()

    def _print_preamble(self, fp):
        """Write everything from \\begin{table} up to and including \\toprule and the optional description."""
        cols = "".join(self.justs)
        table_type = 'sidewaystable' if self.sideways else 'table'
        if self.float_specifiers:
            fp.write('\\begin{' + table_type + '}[' + self.float_specifiers + ']\n')
        else:
            fp.write('\\begin{' + table_type + '}\n')
        fp.write('\\centering\n')
        if self.fontsize:
            fp.write('\\%s\n' % self.fontsize)
        if self.caption and self.caption_position == 'top':
            fp.write('\\caption{%s}\n' % self.caption)
        # \label must follow the \caption it refers to; with a bottom caption it is written in the footer.
        if self.label and not (self.caption and self.caption_position == 'bottom'):
            fp.write('\\label{table:%s}\n' % self.label)
        if self.width == 'auto':
            # Open the \contents macro; _print_footer closes it, measures it and typesets it.
            fp.write('\\setlength{\\linewidth}{.1cm}\\newcommand{\\contents}{\n')
        fp.write('\\begin{tabular}{%s}\n' % cols)

        fp.write('\\toprule\n')
        if self.description:
            fp.write('\\multicolumn{%d}{p{\\linewidth}}{%s} \\\\ \n' % (self.numcols, self.description))

    def _print_headers(self, fp, pn):
        """Write the supheader row, its cmidrules and the header row of panel ``pn``."""
        if pn != 0 or self.description:
            fp.write('\\midrule\n')

        self._print_header_level(fp, self.headers[pn][0], self.header_ids[pn][0])

        # Underline every non-empty supheader across the column(s) it spans.
        for names, ids in zip(self.headers[pn][0], self.header_ids[pn][0]):
            for name, col in zip(names, ids):
                if name != '':
                    first, last = (col[0], col[1]) if len(np.shape(col)) == 1 else (col, col)
                    fp.write('\\cmidrule(lr){%d-%d}\n' % (first, last))

        self._print_header_level(fp, self.headers[pn][1], self.header_ids[pn][1])

    def _print_header_level(self, fp, rows, ids):
        """Write one header level; ``ids[i][j]`` is the 1-based column (int) or (first, last) span of ``rows[i][j]``."""
        for i, names in enumerate(rows):
            end = '' if i == len(rows) - 1 else '\\\\\n'
            for j, name in enumerate(names):
                sep = '&' if j < len(names) - 1 else end
                col = ids[i][j]
                if len(np.shape(col)) == 1:
                    length = col[1] - col[0] + 1
                    fp.write('\\multicolumn{%d}{c}{%s} %s ' % (length, name, sep))
                elif self.justs[col - 1] == 'd':
                    # dcolumn cells align on the decimal point; wrap text so it is centred instead.
                    fp.write('\\multicolumn{1}{c}{%s} %s ' % (name, sep))
                else:
                    fp.write('%s %s ' % (name, sep))
        if rows:
            fp.write('\n\\\\ \n')

    def _print_data(self, fp, pn):
        """Write all data rows of panel ``pn``."""
        if self.headers[pn][0] or self.headers[pn][1]:
            fp.write('\\midrule\n')

        for i, data in enumerate(self.data[pn]):
            nrows = np.shape(data[0])[0]
            rows = [[self._format_cell(pn, i, data, j, k) for k in range(len(data))] for j in range(nrows)]
            for row_i, row in enumerate(rows):
                fp.write(' & '.join(row))
                # With significance stars, add space after each standard-error row to group coefficient/SE pairs.
                if self.stars[pn] and row_i % 2 == 1:
                    fp.write(' \\\\[0.2cm]\n')
                else:
                    fp.write(' \\\\\n')

    def _format_cell(self, pn, i, data, j, k):
        """Return the LaTeX text for row ``j`` of column ``k`` in data block ``i`` of panel ``pn``."""
        value = data[k][j]
        fmt = self.formatting[pn][i][k]
        if type(value) not in FLOAT_TYPES:
            # Non-float cells are written verbatim; in dcolumn columns they are wrapped in the user's \mc macro.
            return '\\mc{' + str(value) + '}' if self.justs[k] == 'd' else str(value)
        if np.isnan(value):
            return '\\ldots'
        if not self.regression[pn] or (k + 1) in self.skip_cols[pn]:
            return str(fmt % value)
        if j % 2 == 0:
            # Coefficient row; its standard error is the next row (row count is even in regression panels).
            if not self.stars[pn]:
                return str(fmt % value)
            dof = self.dof[pn] if isinstance(self.dof[pn], int) else self.dof[pn][k]
            return self._add_stars(value, data[k][j + 1], fmt, self.stars[pn], dof, self.means[pn][j // 2])
        # Standard-error row: enclose in the configured brackets.
        marker = self.bracket_marker[pn][i][k]
        left, right = BRACKETS[marker]
        inner = left + str(fmt % value) + right
        if 'd' in self.justs and marker == 'angled':
            return '\\md{' + inner + '}'
        return inner

    def _print_footer(self, fp):
        """Write \\bottomrule, the footer, the closing environments and (for bottom captions) caption and label."""
        fp.write('\\bottomrule\n')
        if self.footer:
            fp.write('\\multicolumn{%d}{p{\\linewidth}}{\\footnotesize %s}\n' % (self.numcols, self.footer))

        fp.write('\\end{tabular}\n')
        if self.width == 'auto':
            # Close \contents, measure its natural width in box 0, set \linewidth from it and typeset it once.
            txt = '}\n\\setbox0=\\hbox{\\contents}\n\\setlength{\\linewidth}{\\wd0-2\\tabcolsep-.25em}\n\\contents\n'
            fp.write(txt)
        # Written after the auto-width hack so the caption is not inside the twice-expanded \contents macro.
        if self.caption and self.caption_position == 'bottom':
            fp.write('\\caption{%s}\n' % self.caption)
            if self.label:
                fp.write('\\label{table:%s}\n' % self.label)

        table_type = 'sidewaystable' if self.sideways else 'table'
        fp.write('\\end{' + table_type + '}\n')

    @staticmethod
    def _add_stars(value, se, formatting, stars_tup, dof, mean):
        """
        Format a coefficient and append significance stars.

        Uses a two-sided t-test of (value - mean) / se with ``dof`` degrees of freedom: p < level is equivalent to
        CDF(|t|) > 1 - level / 2. Every level in ``stars_tup`` is checked and the last one that passes determines the
        number of stars, so levels are expected in decreasing order.
        """
        base = str(formatting % value)
        if not stars_tup:
            return base
        cdf = t.cdf(abs((value - mean) / se), dof)
        nstars = 0
        for n, level in enumerate(stars_tup, start=1):
            if cdf > 1 - level / 2:
                nstars = n
        return base + '\\textsuperscript{%s}' % ('*' * nstars) if nstars else base
